import argparse
import os
import datetime

def setup_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--world_size", type=int, default=1)

    parser.add_argument("--data_dir", type=str,
                        default="./dataset/opendialkg/")
    parser.add_argument("--t5_model", type=str, default="t5-small")
    parser.add_argument("--train_batch_size", type=int, default=24)
    parser.add_argument("--eval_batch_size", type=int, default=64)
    parser.add_argument("--max_length", help="max length for input document",
                        default=256, type=int)
    parser.add_argument("--knowledge_length", help="max lengh for knowledge",
                        default=128, type=int)
    parser.add_argument("--max_decode_step", type=int,
                        default=128, help="maximum decode step")
    parser.add_argument('--num_train_epochs',
                        help='Number of epochs to train',
                        type=int, default=30)
    parser.add_argument("--load", action="store_true")
    parser.add_argument("--debug", action="store_true",
                        help="whether to activate debugging mode")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--save_steps", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="./save/tmp")
    parser.add_argument("--no_report", action="store_true")
    parser.add_argument("--seed", default=42, type=int,
                        help="Random state(seed)")

    parser.add_argument("--do_eval", action="store_true")
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)

    parser.add_argument("--domain", type=str, default="opendialkg", choices=["opendialkg"])
    parser.add_argument("--use_AdamW", action="store_true")
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--lm_type", type=str, default="t5", choices=["t5"])
    parser.add_argument("--entity_embed_size", type=int, default=512)

    parser.add_argument("--knowledge", type=str,
        choices=["graph", "text", "append_checked_text", "none"], default="graph")
    parser.add_argument("--eval_frequency", type=int, default=1)
    parser.add_argument("--graph_lambda", type=int, default=10)
    parser.add_argument("--num_facts", type=int, default=3, help="The number of facts packed in the single z")
    parser.add_argument("--num_samples", type=int, default=4, help="k for marginalization")
    parser.add_argument("--use_sigmoid_score", action="store_true")
    parser.add_argument("--retriever_learning_rate", type=float, default=1e-4)
    parser.add_argument("--kfm_learning_rate", type=float, default=1e-4)
    parser.add_argument("--infomax_lambda", type=float, default=0.1)
    parser.add_argument("--use_infomax", action="store_true")
    parser.add_argument("--use_augment", action="store_true", help="Use augment for infomax")
    parser.add_argument("--use_contrastive", action="store_true")
    parser.add_argument("--use_self_contrastive", action="store_true")
    # Retrieval arguments
    parser.add_argument("--unsupervised", action="store_true")

    args = parser.parse_args()

    assert args.knowledge_length + args.max_length <= 512, "Total length is maximum 512"

    args.use_sigmoid_score = True

    print(f"args.data_dir: {args.data_dir}")
    print(f"args.use_sigmoid_score: {args.use_sigmoid_score}")
    print(f"args.unsupervised: {args.unsupervised}")
    print(f"args.knowledge_length: {args.knowledge_length}")
    print(f"args.hidden_size: {args.hidden_size}")

    basedir = "./save"
    today = datetime.datetime.now().strftime("%Y%m%d")
    args.output_dir = os.path.join(basedir, today, args.output_dir)

    # Delete this if you use WanDB
    # args.no_report = True
    args.use_AdamW = True

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    print(f"SEED: {args.seed}")
    return args